#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig


def is_mindie_turbo_supported() -> bool:
    """Report whether the optional ``mindie_turbo`` package can be imported.

    Returns:
        ``True`` when ``import mindie_turbo`` succeeds, ``False`` when it
        raises ``ImportError``.
    """
    supported = True
    try:
        # Imported purely as an availability probe; the module is not used.
        import mindie_turbo  # noqa: F401
    except ImportError:
        supported = False
    return supported


def example_quantization(model_name_or_path: str, tmp_path: str) -> None:
    """Quantize a causal LM with msmodelslim and write the result to ``tmp_path``.

    The function loads the tokenizer and model from ``model_name_or_path``
    onto device ``npu:0``, builds a calibration set from a fixed prompt list,
    runs ``AntiOutlier`` followed by ``Calibrator`` (8-bit weights and
    activations), and finally saves the quantized weights, the model config
    and the tokenizer into ``tmp_path``.

    Args:
        model_name_or_path: Identifier or local path passed to
            ``from_pretrained`` for both the tokenizer and the model.
        tmp_path: Output directory that receives the quantized weights,
            the model config and the tokenizer files.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path
    )

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        device_map="npu:0",
        torch_dtype="auto"
    ).eval()

    data_list = ["What's deep learning?"]
    # One calibration entry per prompt. The append used to sit outside the
    # loop, so only the last prompt was ever added to the calibration set.
    dataset_calib = []
    for calib_data in data_list:
        inputs = tokenizer(calib_data, return_tensors='pt').to("npu:0")
        dataset_calib.append([inputs.data['input_ids']])

    anti_config = AntiOutlierConfig(anti_method="m2", dev_type="npu", dev_id=0)
    anti_outlier = AntiOutlier(model, calib_data=dataset_calib, cfg=anti_config)
    anti_outlier.process()

    # Layers listed here are passed to QuantConfig as ``disable_names``.
    # NOTE(review): the layer count is hard-coded to 24; presumably this
    # matches the model the example targets -- confirm before reusing it
    # with a different model.
    disable_names = ['lm_head']
    disable_names.extend(
        f'model.layers.{layer_index}.mlp.down_proj'
        for layer_index in range(24)
    )

    quant_config = QuantConfig(
        a_bit=8,
        w_bit=8,
        disable_names=disable_names,
        dev_type='npu',
        dev_id=0,
        act_method=3,
        pr=1.0,
        w_sym=True,
        mm_tensor=False
    )

    calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level='L0')
    calibrator.run()

    # Currently, we need to add config.json manually for quantized weights
    # generated by msmodelslim. The following code will be removed once
    # msmodelslim can generate complete weights, except for
    # 'calibrator.save(tmp_path, save_type=["safe_tensor"])'.
    class EmptyModule(torch.nn.Module):
        """Parameter-free module whose ``state_dict()`` is empty."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            return x

    calibrator.model.config.quantization_config = calibrator.quant_model_json_description.quant_model_description

    calibrator.save(tmp_path, save_type=["safe_tensor"])
    # Pass an empty state dict to save_pretrained: the quantized weights were
    # already written by calibrator.save above.
    calibrator.model.save_pretrained(tmp_path, state_dict=EmptyModule().state_dict())
    tokenizer.save_pretrained(tmp_path)